package com.ld.shieldsb.canalclient.handler.impl.db;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;

import javax.sql.DataSource;

import com.ld.shieldsb.canalclient.etl.AbstractEtlService;
import com.ld.shieldsb.canalclient.etl.EtlConsumer;
import com.ld.shieldsb.canalclient.handler.config.AdapterConfig.AdapterMapping;
import com.ld.shieldsb.canalclient.handler.config.MappingConfig;
import com.ld.shieldsb.canalclient.handler.config.MappingConfig.DbMapping;
import com.ld.shieldsb.canalclient.util.CanalUtil;
import com.ld.shieldsb.canalclient.util.SyncUtil;

/**
 * RDB ETL 操作业务类
 *
 * @author rewerma @ 2018-11-7
 * @version 1.0.0
 */
public class RdbEtlService extends AbstractEtlService {

    private DataSource targetDS;

    public RdbEtlService(DataSource targetDS, MappingConfig config) {
        super("RDB", config);
        this.targetDS = targetDS;
        this.config = config;
    }

    /**
     * 执行导入，先删除旧数据再插入新数据
     */
    @Override
    protected boolean executeSqlImport(DataSource srcDS, String sql, List<Object> values, long cnt, AdapterMapping mapping,
            AtomicLong impCount, List<String> errMsg, Consumer<EtlConsumer> con) {
        boolean result = false;
        try {
            DbMapping dbMapping = (DbMapping) mapping;
            // 列映射，目标为key，源字段为value，保留大小写
            Map<String, String> columnsMap = new LinkedHashMap<>();
            // 列类型，key忽略大小写
            Map<String, Integer> columnType = new LinkedHashMap<>();

            CanalUtil.sqlRS(targetDS, "SELECT * FROM " + SyncUtil.getDbTableName(dbMapping) + " LIMIT 1 ", rs -> {
                try {

                    ResultSetMetaData rsd = rs.getMetaData();
                    int columnCount = rsd.getColumnCount();
                    // 获取所有的列名
                    List<String> columns = new ArrayList<>();
                    for (int i = 1; i <= columnCount; i++) {
                        columnType.put(rsd.getColumnName(i).toLowerCase(), rsd.getColumnType(i));
                        columns.add(rsd.getColumnName(i));
                    }

                    columnsMap.putAll(SyncUtil.getColumnsMap(dbMapping, columns));
                    return true;
                } catch (Exception e) {
                    logger.error(e.getMessage(), e);
                    return false;
                }
            });
            // 执行查询
            CanalUtil.sqlRS(srcDS, sql, values, rs -> {
                int idx = 1;

                try {
                    boolean completed = false; // 是否完成

                    StringBuilder insertSql = new StringBuilder();
                    insertSql.append("INSERT INTO ").append(SyncUtil.getDbTableName(dbMapping)).append(" (");
                    columnsMap.forEach((targetColumnName, srcColumnName) -> insertSql.append(targetColumnName).append(","));

                    int len = insertSql.length();
                    insertSql.delete(len - 1, len).append(") VALUES (");
                    int mapLen = columnsMap.size();
                    for (int i = 0; i < mapLen; i++) {
                        insertSql.append("?,");
                    }
                    len = insertSql.length();
                    insertSql.delete(len - 1, len).append(")");
                    try (Connection connTarget = targetDS.getConnection();
                            PreparedStatement pstmt = connTarget.prepareStatement(insertSql.toString())) {
                        connTarget.setAutoCommit(false);

                        while (rs.next()) {
                            completed = false;

                            pstmt.clearParameters();

                            // 删除数据
                            Map<String, Object> pkVal = new LinkedHashMap<>();
                            StringBuilder deleteSql = new StringBuilder("DELETE FROM " + SyncUtil.getDbTableName(dbMapping) + " WHERE ");
                            appendCondition(dbMapping, deleteSql, pkVal, rs);
                            try (PreparedStatement pstmt2 = connTarget.prepareStatement(deleteSql.toString())) {
                                int k = 1;
                                for (Object val : pkVal.values()) {
                                    pstmt2.setObject(k++, val);
                                }
                                pstmt2.execute();
                            }

                            int i = 1;
                            for (Map.Entry<String, String> entry : columnsMap.entrySet()) {
                                String targetClolumnName = entry.getKey();
                                String srcColumnName = entry.getValue();
                                if (srcColumnName == null) {
                                    srcColumnName = targetClolumnName;
                                }

                                Integer type = columnType.get(targetClolumnName.toLowerCase());

                                Object value = rs.getObject(srcColumnName);
                                if (value != null) {
                                    SyncUtil.setPStmt(type, pstmt, value, i);
                                } else {
                                    pstmt.setNull(i, type);
                                }

                                i++;
                            }
                            // 提交数据，注意不是提交事务，事务未提交前还看不到
                            pstmt.execute();
                            if (logger.isTraceEnabled()) {
                                logger.trace("Insert into target table, sql: {}", insertSql);
                            }

                            if (idx % dbMapping.getCommitBatch() == 0) { // 临时提交
                                connTarget.commit();
                                completed = true;
                            }
                            idx++;
                            impCount.incrementAndGet();
                            if (logger.isDebugEnabled()) {
                                logger.debug("successful import count:" + impCount.get());
                            }
                        }
                        // 循环完后提交事务
                        if (!completed) {
                            connTarget.commit();
                        }
                    }

                } catch (Exception e) {
                    String msg = e.getMessage();
                    if (msg.contains("doesn't have a default value")) {
                        msg = msg.replace("Field", "字段").replace("doesn't have a default value", "不能为空，但是没有默认值！");
                    }
                    logger.error(dbMapping.getTable() + " etl 失败! ==>" + msg, e);
                    errMsg.add(dbMapping.getTable() + " etl 失败! ==>" + msg);
                }
                return idx;
            });
            result = true;
        } catch (Exception e) {
            logger.error(e.getMessage(), e);
        } finally {
            con.accept(EtlConsumer.builder().processState(EtlConsumer.PROCESS_STATE_END).srcDS(srcDS).sql(sql).values(values)
                    .mapping(mapping).impCount(impCount).errMsg(errMsg).success(result).count(cnt).build());
        }
        return result;
    }

    /**
     * 拼接目标表主键where条件
     */
    private static void appendCondition(DbMapping dbMapping, StringBuilder sql, Map<String, Object> values, ResultSet rs)
            throws SQLException {
        // 拼接主键
        for (Map.Entry<String, String> entry : dbMapping.getTargetPk().entrySet()) {
            String targetColumnName = entry.getKey();
            String srcColumnName = entry.getValue();
            if (srcColumnName == null) {
                srcColumnName = targetColumnName;
            }
            sql.append(targetColumnName).append("=? AND ");
            values.put(targetColumnName, rs.getObject(srcColumnName));
        }
        int len = sql.length();
        sql.delete(len - 4, len);
    }
}
